SHapley Additive exPlanations for models trained to predict players' wages in the FIFA 23 game

Task A

We analyze the random forest model trained in the previous assignment on the FIFA 23 game dataset. The model aims to predict players' wages given their statistics in the game.

We use SHapley Additive exPlanations to compute feature importance.

We also compare the random forest model with the linear model trained in the previous assignment.
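The plots in this report show per-observation attributions. One common way to turn them into a single ranking is the mean absolute SHAP value per feature. The block below is a minimal sketch of that aggregation; the function name `mean_abs_importance` and the toy numbers are hypothetical and are not taken from the FIFA 23 models.

```python
from typing import List, Sequence, Tuple


def mean_abs_importance(
    shap_matrix: Sequence[Sequence[float]],
    feature_names: Sequence[str],
) -> List[Tuple[str, float]]:
    """Rank features by mean |SHAP value| over all rows, largest first."""
    n_rows = len(shap_matrix)
    totals = [0.0] * len(feature_names)
    for row in shap_matrix:
        for j, value in enumerate(row):
            totals[j] += abs(value)
    ranked = [(name, total / n_rows) for name, total in zip(feature_names, totals)]
    return sorted(ranked, key=lambda item: item[1], reverse=True)


# Toy attributions (made-up numbers, one row per observation).
toy_names = ["Overall", "Value(in Euro)", "Age"]
toy_shap = [
    [4000.0, -1500.0, 200.0],
    [-3000.0, 2500.0, -100.0],
]
ranking = mean_abs_importance(toy_shap, toy_names)
print(ranking)
```

With a real model, the rows would come from something like `explainer_shap.shap_values(X).tolist()`.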

Observations in the dataset with different variables of the highest importance

We found one observation whose most important features (according to the shap library) are Value in euro and Ball control, and another whose most important features are Overall and Country.

Finding such observations required a fair amount of searching, because for most observations the features with the highest impact are Overall and Value in euro.
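The manual search through bar plots could also be automated. The following is a minimal sketch, assuming the attributions are available as a list of rows (for example `grid_shap_values.tolist()`); the helper names `top_features` and `find_unusual_rows` and the toy values are hypothetical.

```python
from collections import Counter


def top_features(row, feature_names, k=2):
    """Return the k feature names with the largest |attribution| in one row."""
    order = sorted(range(len(row)), key=lambda j: abs(row[j]), reverse=True)
    return tuple(feature_names[j] for j in order[:k])


def find_unusual_rows(shap_matrix, feature_names, k=2):
    """Indices of rows whose top-k feature set differs from the most common one."""
    tops = [frozenset(top_features(row, feature_names, k)) for row in shap_matrix]
    most_common_set, _ = Counter(tops).most_common(1)[0]
    return [i for i, s in enumerate(tops) if s != most_common_set]


# Toy attributions (made-up numbers).
names = ["Overall", "Value(in Euro)", "Ball control", "Country"]
toy_shap = [
    [5000.0, 3000.0, 100.0, 50.0],    # top 2: Overall, Value
    [-4000.0, 6000.0, 200.0, -10.0],  # top 2: Value, Overall
    [300.0, 2500.0, 900.0, 20.0],     # top 2: Value, Ball control
    [1500.0, 100.0, 50.0, -700.0],    # top 2: Overall, Country
]
unusual = find_unusual_rows(toy_shap, names)
print(unusual, [top_features(toy_shap[i], names) for i in unusual])
```

Rows returned by `find_unusual_rows` are candidates for closer inspection with `shap.bar_plot`.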

The first observation:

Point4_shap_im1.png

The second observation:

Point4_shap_im2.png

The results from the dalex library are very similar, although computing Shapley values with dalex took much longer.
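A plausible reason for the difference in running time is that a model-agnostic method has to call the model many times, whereas `shap.TreeExplainer` exploits the tree structure. To our understanding, dalex with `type = 'shap'` averages contributions over a number of random feature orderings. The block below is a minimal sketch of that general idea, not the dalex implementation; `sampled_shapley` and all values are hypothetical.

```python
import random


def sampled_shapley(predict, x, background, n_orderings=25, seed=0):
    """Estimate Shapley values by averaging over random feature orderings."""
    rng = random.Random(seed)
    n_features = len(x)

    def expected_prediction(fixed):
        # Features in `fixed` take their value from x, the rest from background rows.
        total = 0.0
        for row in background:
            mixed = [x[j] if j in fixed else row[j] for j in range(n_features)]
            total += predict(mixed)
        return total / len(background)

    contributions = [0.0] * n_features
    for _ in range(n_orderings):
        order = list(range(n_features))
        rng.shuffle(order)
        fixed = set()
        previous = expected_prediction(fixed)
        for j in order:
            fixed.add(j)
            current = expected_prediction(fixed)
            contributions[j] += current - previous
            previous = current
    return [c / n_orderings for c in contributions]


# Toy additive model, for which every ordering gives the same contributions.
def toy_predict(v):
    return 2.0 * v[0] - 3.0 * v[1] + 0.5 * v[2]


toy_background = [[0.0, 0.0, 0.0], [2.0, 4.0, 6.0]]
toy_x = [3.0, 1.0, 5.0]
estimate = sampled_shapley(toy_predict, toy_x, toy_background)
print(estimate)
```

Each ordering costs (number of features + 1) passes over the background data, so with 111 features the number of model calls grows quickly.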

The first observation in dalex:

Point4_dalex_im1.png

The second observation in dalex:

Point4_dalex_im2.png

Observations and a variable in the dataset, such that the variable has a positive impact on one observation and a negative impact on the other

According to the shap library, the variable Overall has a positive impact on one observation and a negative impact on another.

Finding such a pair of observations was quite easy.
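Such pairs can also be found programmatically by checking which features have attributions of both signs. This is a minimal sketch under the assumption that the attributions are given as a list of rows; `sign_flipping_features` and the toy values are hypothetical.

```python
def sign_flipping_features(shap_matrix, feature_names):
    """Map each feature with both signs to (row with positive, row with negative)."""
    result = {}
    for j, name in enumerate(feature_names):
        pos = next((i for i, row in enumerate(shap_matrix) if row[j] > 0), None)
        neg = next((i for i, row in enumerate(shap_matrix) if row[j] < 0), None)
        if pos is not None and neg is not None:
            result[name] = (pos, neg)
    return result


# Toy attributions (made-up numbers): Overall changes sign, Age does not.
names = ["Overall", "Age"]
toy_shap = [
    [1200.0, 50.0],
    [-800.0, 30.0],
]
flips = sign_flipping_features(toy_shap, names)
print(flips)
```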

The first observation:

Point5_shap_im1.png

The second observation:

Point5_shap_im2.png

The results are analogous for the dalex library.

The first observation in dalex:

Point5_dalex_im1.png

The second observation in dalex:

Point5_dalex_im2.png

Comparison of Shapley values between the tree model and the linear model

The most important features for the random forest model and the linear model usually differ. The tree model mostly focuses on the Overall and Value in euro features, while the linear model often focuses on the Stats and Position ratings.
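One way to see why the linear model can highlight different features: if the features are treated as independent, the SHAP value of feature i for a linear model is its coefficient times the deviation of x_i from the feature mean. A feature with a large coefficient or an unusual value therefore dominates. The sketch below illustrates this closed form; `linear_shap` and all numbers are hypothetical, not the fitted coefficients.

```python
def linear_shap(coefficients, x, feature_means):
    """SHAP values of a linear model under the feature-independence assumption."""
    return [w * (xi - m) for w, xi, m in zip(coefficients, x, feature_means)]


def linear_predict(coefficients, intercept, x):
    """Prediction of a linear model: intercept + sum of w_i * x_i."""
    return intercept + sum(w * xi for w, xi in zip(coefficients, x))


# Toy linear model (made-up numbers).
coefs = [2.0, -1.0]
intercept = 10.0
means = [3.0, 1.0]
x = [5.0, 3.0]

phi = linear_shap(coefs, x, means)
print(phi)

# Attributions sum to the prediction minus the prediction at the feature means.
gap = linear_predict(coefs, intercept, x) - linear_predict(coefs, intercept, means)
print(sum(phi), gap)
```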

The most important features for the tree model and a chosen observation:

Point7_shap_im1.png

The most important features for the linear model and the same observation:

Point7_shap_im2.png

Task B

Players A and B are symmetric, so the Shapley values S_a and S_b are equal. Since S_a + S_b + S_c = v(A, B, C) = 100, we have S_c = 100 - 2 * S_a.

We have 3! * S_a = 20 * 2! + (60 - 20) * 1! * 1! + (70 - 60) * 1! * 1! + (100 - 70) * 2!, where the terms correspond, respectively, to A in the first position, A in the second position after B, A in the second position after C, and A in the third position.

Thus 6 * S_a = 40 + 40 + 10 + 60 = 150 and so S_a = 25.

So we have S_a = 25, S_b = 25, S_c = 100 - 2 * S_a = 50.
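The result can be checked by enumerating all 3! orderings of the players. The sketch below assumes the coalition values read off from the derivation above (v(A) = v(B) = 20, v(C) = 60, v(A, B) = 60, v(A, C) = v(B, C) = 70, v(A, B, C) = 100) and the usual convention that the empty coalition has value 0; the function name `shapley_values` is hypothetical.

```python
from fractions import Fraction
from itertools import permutations

# Coalition values as used in the derivation; v(empty) = 0 is an assumption.
v = {
    frozenset(): 0,
    frozenset("A"): 20,
    frozenset("B"): 20,
    frozenset("C"): 60,
    frozenset("AB"): 60,
    frozenset("AC"): 70,
    frozenset("BC"): 70,
    frozenset("ABC"): 100,
}


def shapley_values(players, value):
    """Average each player's marginal contribution over all orderings."""
    totals = {p: Fraction(0) for p in players}
    orderings = list(permutations(players))
    for order in orderings:
        coalition = frozenset()
        for p in order:
            with_p = coalition | {p}
            totals[p] += value[with_p] - value[coalition]
            coalition = with_p
    return {p: t / len(orderings) for p, t in totals.items()}


result = shapley_values("ABC", v)
print(result)
```

The values sum to v(A, B, C) = 100, in line with the efficiency property used in the derivation.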

Appendix

Importing libraries and dataset

In [36]:
# 1. Import libraries

!pip3 install shap
!pip3 install dalex
!pip3 install -q condacolab
import condacolab
condacolab.install()
!conda install -c conda-forge python-kaleido

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly
import kaleido

import pickle
import shap
import dalex as dx

from math import isclose

from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: shap in /usr/local/lib/python3.7/site-packages (0.41.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/site-packages (from shap) (1.7.3)
Requirement already satisfied: numba in /usr/local/lib/python3.7/site-packages (from shap) (0.56.3)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from shap) (1.21.6)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/site-packages (from shap) (1.3.5)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/site-packages (from shap) (1.0.2)
Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/site-packages (from shap) (4.64.0)
Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.7/site-packages (from shap) (21.3)
Requirement already satisfied: slicer==0.0.7 in /usr/local/lib/python3.7/site-packages (from shap) (0.0.7)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/site-packages (from shap) (2.2.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging>20.9->shap) (3.0.9)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/site-packages (from numba->shap) (5.0.0)
Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.7/site-packages (from numba->shap) (0.39.1)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/site-packages (from numba->shap) (65.3.0)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/site-packages (from pandas->shap) (2.8.2)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/site-packages (from pandas->shap) (2022.5)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn->shap) (3.1.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/site-packages (from scikit-learn->shap) (1.2.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->shap) (1.16.0)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/site-packages (from importlib-metadata->numba->shap) (3.10.0)
Requirement already satisfied: typing-extensions>=3.6.4 in /usr/local/lib/python3.7/site-packages (from importlib-metadata->numba->shap) (4.4.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: dalex in /usr/local/lib/python3.7/site-packages (1.5.0)
Requirement already satisfied: tqdm>=4.61.2 in /usr/local/lib/python3.7/site-packages (from dalex) (4.64.0)
Requirement already satisfied: plotly>=5.1.0 in /usr/local/lib/python3.7/site-packages (from dalex) (5.10.0)
Requirement already satisfied: scipy>=1.6.3 in /usr/local/lib/python3.7/site-packages (from dalex) (1.7.3)
Requirement already satisfied: pandas>=1.2.5 in /usr/local/lib/python3.7/site-packages (from dalex) (1.3.5)
Requirement already satisfied: numpy>=1.20.3 in /usr/local/lib/python3.7/site-packages (from dalex) (1.21.6)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/site-packages (from dalex) (65.3.0)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/site-packages (from pandas>=1.2.5->dalex) (2022.5)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/site-packages (from pandas>=1.2.5->dalex) (2.8.2)
Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.7/site-packages (from plotly>=5.1.0->dalex) (8.1.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas>=1.2.5->dalex) (1.16.0)
✨🍰✨ Everything looks OK!
Collecting package metadata (current_repodata.json): done
Solving environment: done

# All requested packages already installed.

Retrieving notices: ...working... done
In [37]:
# 2. Load dataset and models from the previous homework (POINT 1)

with open('X_train.pickle', 'rb') as handle:
    X_train_load = pickle.load(handle)

with open('y_train.pickle', 'rb') as handle:
    y_train_load = pickle.load(handle)

with open('X_test.pickle', 'rb') as handle:
    X_test_load = pickle.load(handle)

with open('y_test.pickle', 'rb') as handle:
    y_test_load = pickle.load(handle)

with open('tree_model.pickle', 'rb') as handle:
    forest_reg_load = pickle.load(handle)

with open('linear_model.pickle', 'rb') as handle:
    linear_model_load = pickle.load(handle)

print(X_train_load)
print(y_train_load)
print(X_test_load)
# Note: this prints y_train_load a second time; y_test_load was probably intended here.
print(y_train_load)
print(forest_reg_load.predict(X_train_load))
print(linear_model_load.predict(X_train_load))
print(forest_reg_load.predict(X_test_load))
print(linear_model_load.predict(X_test_load))
       Overall  Potential  Value(in Euro)  Age  Height(in cm)  Weight(in kg)  \
1127        76         77         9000000   27            190             85   
6725        68         68         1200000   28            190             87   
4966        70         70         1100000   31            190             93   
1799        75         75         6000000   22            188             84   
4484        70         70         1300000   32            165             63   
...        ...        ...             ...  ...            ...            ...   
11284       64         64          525000   28            183             83   
11964       63         69          700000   23            176             70   
5390        69         69          375000   35            187             85   
860         77         77        10500000   29            176             73   
15795       59         59          220000   29            170             70   

       TotalStats  BaseStats  Release Clause  Weak Foot Rating  ...  \
1127         1808        392               0                 3  ...   
6725         1611        344         1600000                 2  ...   
4966         1594        338         1600000                 3  ...   
1799         1717        364        11400000                 3  ...   
4484         1773        365         2200000                 3  ...   
...           ...        ...             ...               ...  ...   
11284        1643        355          919000                 2  ...   
11964        1625        353               0                 3  ...   
5390         1672        359          750000                 3  ...   
860          2037        430        20000000                 3  ...   
15795        1564        335          363000                 4  ...   

       Best_Position_RM  Best_Position_RW  Best_Position_RWB  \
1127                  0                 0                  0   
6725                  0                 0                  0   
4966                  0                 0                  0   
1799                  0                 0                  0   
4484                  0                 0                  0   
...                 ...               ...                ...   
11284                 0                 0                  0   
11964                 0                 0                  0   
5390                  0                 0                  0   
860                   0                 0                  0   
15795                 0                 0                  0   

       Best_Position_ST  Attacking_Work_Rate_High  Attacking_Work_Rate_Low  \
1127                  1                         0                        0   
6725                  1                         0                        0   
4966                  0                         0                        1   
1799                  0                         0                        0   
4484                  0                         0                        0   
...                 ...                       ...                      ...   
11284                 0                         0                        1   
11964                 0                         0                        0   
5390                  0                         0                        1   
860                   0                         1                        0   
15795                 0                         0                        0   

       Attacking_Work_Rate_Medium  Defensive_Work_Rate_High  \
1127                            1                         0   
6725                            1                         0   
4966                            0                         1   
1799                            1                         0   
4484                            1                         0   
...                           ...                       ...   
11284                           0                         1   
11964                           1                         0   
5390                            0                         1   
860                             0                         0   
15795                           1                         0   

       Defensive_Work_Rate_Low  Defensive_Work_Rate_Medium  
1127                         0                           1  
6725                         1                           0  
4966                         0                           0  
1799                         0                           1  
4484                         1                           0  
...                        ...                         ...  
11284                        0                           0  
11964                        0                           1  
5390                         0                           0  
860                          0                           1  
15795                        0                           1  

[14831 rows x 111 columns]
1127     48000
6725      6000
4966      8000
1799     13000
4484     18000
         ...  
11284     2000
11964     2000
5390      6000
860      46000
15795     5000
Name: Wage(in Euro), Length: 14831, dtype: int64
       Overall  Potential  Value(in Euro)  Age  Height(in cm)  Weight(in kg)  \
10157       65         75         1500000   22            180             79   
3617        72         72         1900000   31            179             77   
4894        70         73         2100000   26            183             73   
2315        74         81         8000000   23            178             72   
2177        74         75         5000000   26            175             71   
...        ...        ...             ...  ...            ...            ...   
9049        66         66          725000   30            170             72   
14757       61         74          800000   21            184             78   
6779        68         76         2500000   23            182             73   
3269        72         72         2300000   30            185             82   
11602       64         64          500000   29            181             75   

       TotalStats  BaseStats  Release Clause  Weak Foot Rating  ...  \
10157        1561        344         3600000                 2  ...   
3617         1987        409         3600000                 4  ...   
4894         1975        410         4600000                 2  ...   
2315         1803        393        16800000                 3  ...   
2177         1854        385               0                 4  ...   
...           ...        ...             ...               ...  ...   
9049         1777        371         1600000                 3  ...   
14757        1485        333         1600000                 3  ...   
6779         1774        380         4099999                 3  ...   
3269         1825        387         5100000                 4  ...   
11602        1580        354          875000                 4  ...   

       Best_Position_RM  Best_Position_RW  Best_Position_RWB  \
10157                 0                 0                  0   
3617                  1                 0                  0   
4894                  0                 0                  0   
2315                  0                 0                  0   
2177                  0                 0                  0   
...                 ...               ...                ...   
9049                  0                 0                  0   
14757                 0                 0                  0   
6779                  0                 0                  0   
3269                  0                 0                  0   
11602                 0                 0                  0   

       Best_Position_ST  Attacking_Work_Rate_High  Attacking_Work_Rate_Low  \
10157                 0                         0                        0   
3617                  0                         0                        0   
4894                  0                         1                        0   
2315                  0                         0                        0   
2177                  1                         1                        0   
...                 ...                       ...                      ...   
9049                  0                         0                        1   
14757                 0                         0                        0   
6779                  0                         0                        0   
3269                  1                         0                        0   
11602                 0                         0                        0   

       Attacking_Work_Rate_Medium  Defensive_Work_Rate_High  \
10157                           1                         0   
3617                            1                         1   
4894                            0                         0   
2315                            1                         0   
2177                            0                         0   
...                           ...                       ...   
9049                            0                         0   
14757                           1                         0   
6779                            1                         0   
3269                            1                         0   
11602                           1                         0   

       Defensive_Work_Rate_Low  Defensive_Work_Rate_Medium  
10157                        0                           1  
3617                         0                           0  
4894                         0                           1  
2315                         0                           1  
2177                         0                           1  
...                        ...                         ...  
9049                         0                           1  
14757                        0                           1  
6779                         0                           1  
3269                         1                           0  
11602                        0                           1  

[3708 rows x 111 columns]
1127     48000
6725      6000
4966      8000
1799     13000
4484     18000
         ...  
11284     2000
11964     2000
5390      6000
860      46000
15795     5000
Name: Wage(in Euro), Length: 14831, dtype: int64
[45130.  5730.  8335. ...  5429. 41640.  4170.]
[36461.14113481  4366.01557268  9049.04479392 ...  3726.18901892
 33882.66952547  3284.73218782]
[ 1992.  12406.5 10253.5 ...  4265.  14385.5  2641. ]
[ 5452.90561019 11907.54997679 10061.40845932 ...  8209.27849936
 10756.15177953  5845.96896926]

Analysis of the model

In [15]:
# 1. Observe predictions of two observations (POINT 2)

observations = X_test_load.sample(2, random_state = 1)

predictions = forest_reg_load.predict(observations)

print(observations)
print(predictions)
      Overall  Potential  Value(in Euro)  Age  Height(in cm)  Weight(in kg)  \
34         87         87        63000000   32            185             76   
8193       67         67          825000   31            183             85   

      TotalStats  BaseStats  Release Clause  Weak Foot Rating  ...  \
34          2140        443       104000000                 4  ...   
8193        1880        384         1600000                 3  ...   

      Best_Position_RM  Best_Position_RW  Best_Position_RWB  Best_Position_ST  \
34                   0                 0                  0                 0   
8193                 0                 0                  0                 0   

      Attacking_Work_Rate_High  Attacking_Work_Rate_Low  \
34                           1                        0   
8193                         0                        0   

      Attacking_Work_Rate_Medium  Defensive_Work_Rate_High  \
34                             0                         1   
8193                           1                         0   

      Defensive_Work_Rate_Low  Defensive_Work_Rate_Medium  
34                          0                           0  
8193                        0                           1  

[2 rows x 111 columns]
[174720.   7040.]
In [16]:
# 2. Calculate Shapley values for selected observations with shap library (POINT 3)

explainer_shap = shap.TreeExplainer(forest_reg_load)
shap_values = explainer_shap.shap_values(observations)
assert(isclose(np.abs(shap_values.sum(1) + explainer_shap.expected_value - predictions).max(), 0.0, abs_tol = 1e-06))
shap.bar_plot(shap_values[0], feature_names = observations.columns, show = False)
plt.savefig('Point3_shap_im_1.png', dpi=300, bbox_inches='tight')
shap.bar_plot(shap_values[1], feature_names = observations.columns, show = False)
plt.savefig('Point3_shap_im_2.png', dpi=300, bbox_inches='tight')
In [17]:
# 3. Calculate Shapley values for selected observations with dalex library (POINT 3)

explainer_dx = dx.Explainer(forest_reg_load, X_test_load, y_test_load, label='default', verbose = False)
shap_values_dx = [explainer_dx.predict_parts(observations.iloc[i], type = 'shap') for i in range(len(observations))]
/usr/local/lib/python3.7/site-packages/sklearn/base.py:451: UserWarning:

X does not have valid feature names, but RandomForestRegressor was fitted with feature names

In [18]:
# 4. Plot and save dalex Shapley values for selected observations (POINT 3)

plot_id = 3
for shap_value_dx in shap_values_dx:
    fig = shap_value_dx.plot(show = False)
    fig.write_image('Point3_dalex_im' + str(plot_id) + '.png')
    plot_id += 1
In [19]:
# 5. Searching for two observations in the dataset, such that they have different variables of the highest importance (POINT 4)

new_observations = X_test_load.sample(300, random_state = 25)
grid_shap_values = explainer_shap.shap_values(new_observations)
for grid_shap_value in grid_shap_values:
    shap.bar_plot(grid_shap_value, feature_names = new_observations.columns)
In [20]:
# 6. Plotting found observations with different variables of the highest importance (POINT 4)

found_observations = new_observations.iloc[[-20, -14]]
found_shap_values = explainer_shap.shap_values(found_observations)
plot_id = 1
for found_shap_value in found_shap_values:
    shap.bar_plot(found_shap_value, feature_names = new_observations.columns, show = False)
    plt.savefig('Point4_shap_im' + str(plot_id) + '.png', dpi=300, bbox_inches='tight')
    plot_id += 1
In [25]:
# 7. Plotting dalex Shapley values for found observations (POINT 4)

found_shap_values_dx = [explainer_dx.predict_parts(found_observations.iloc[i], type = 'shap') for i in range(len(found_observations))]
In [26]:
plot_id = 1
for found_shap_value_dx in found_shap_values_dx:
    fig = found_shap_value_dx.plot(show = False)
    fig.write_image('Point4_dalex_im' + str(plot_id) + '.png')
    plot_id += 1
In [23]:
# 8. Searching for two observations and a variable in the dataset,
# such that the variable has positive impact on one observation and negative impact on the other one (POINT 5)

new_observations = X_test_load.sample(20, random_state = 1)
grid_shap_values = explainer_shap.shap_values(new_observations)
for grid_shap_value in grid_shap_values:
    shap.bar_plot(grid_shap_value, feature_names = new_observations.columns)
In [33]:
# 9. Show selected value and observations (POINT 5)

found_observations = X_test_load.sample(20, random_state = 1)[:2]
found_shap_values = explainer_shap.shap_values(found_observations)
plot_id = 1
for found_shap_value in found_shap_values:
    shap.bar_plot(found_shap_value, feature_names = found_observations.columns, show = False)
    plt.savefig('Point5_shap_im' + str(plot_id) + '.png', dpi=300, bbox_inches='tight')
    plot_id += 1
In [34]:
# 10. Plotting dalex Shapley values for found observations (POINT 5)

found_shap_values_dx = [explainer_dx.predict_parts(found_observations.iloc[i], type = 'shap') for i in range(len(found_observations))]
In [35]:
plot_id = 1
for found_shap_value_dx in found_shap_values_dx:
    fig = found_shap_value_dx.plot(show = False)
    fig.write_image('Point5_dalex_im' + str(plot_id) + '.png')
    plot_id += 1
In [45]:
# 11. Searching for an observation such that SHAP attributions are different
# between the tree model and the linear model (POINT 7)

explainer_linear_shap = shap.Explainer(linear_model_load, X_train_load)
shap_values = explainer_shap.shap_values(observations)

new_observations = X_test_load.sample(20, random_state = 10)
grid_shap_values = explainer_shap.shap_values(new_observations)
grid_linear_shap_values = explainer_linear_shap(new_observations).values
for i in range(len(new_observations)):
    grid_shap_value = grid_shap_values[i]
    grid_linear_shap_value = grid_linear_shap_values[i]
    shap.bar_plot(grid_shap_value, feature_names = new_observations.columns)
    shap.bar_plot(grid_linear_shap_value, feature_names = new_observations.columns)
In [48]:
# 12. Show the selected observation and plot Shapley values for both models (POINT 7)

new_observations = X_test_load.sample(1, random_state = 10)
grid_shap_values = explainer_shap.shap_values(new_observations)
grid_linear_shap_values = explainer_linear_shap(new_observations).values
plot_id = 1
for i in range(len(new_observations)):
    grid_shap_value = grid_shap_values[i]
    grid_linear_shap_value = grid_linear_shap_values[i]
    shap.bar_plot(grid_shap_value, feature_names = new_observations.columns, show = False)
    plt.savefig('Point7_shap_im' + str(plot_id) + '.png', dpi=300, bbox_inches='tight')
    plot_id += 1
    shap.bar_plot(grid_linear_shap_value, feature_names = new_observations.columns, show = False)
    plt.savefig('Point7_shap_im' + str(plot_id) + '.png', dpi=300, bbox_inches='tight')
    plot_id += 1